import os
import struct
import logging
import binascii

from textwrap import dedent

import angr
from angr.storage.file import SimFileDescriptorDuplex

from ..enums import CrashInputType
from ..scripter import Scripter
from .actions import RexSendAction

# Module logger, used for debug/warning messages while building exploit scripts.
l = logging.getLogger("rex.exploit.exploit")


class ExploitException(Exception):
    """Error raised by the exploit helpers in this module when an exploit payload cannot be produced."""


class Exploit:
    """
    An Exploit object represents the successful application of an exploit technique to a crash state.

    It contains the logic for extracting an understanding of how to interact with a state in order to reproduce it
    concretely, and then the logic for encoding that understanding as an exploit script.
    """

    def __init__(self, crash, bypasses_nx, bypasses_aslr, target_ip=None):
        """
        :param crash: a crash object which has been modified to exploit a vulnerability
        :param bypasses_nx: does the exploit bypass NX?
        :param bypasses_aslr: does the exploit bypass ASLR?
        :param target_ip: what is the address that we are trying to control? This way we can parameterize it a bit
        """

        self.bypasses_nx = bypasses_nx
        self.bypasses_aslr = bypasses_aslr

        self.crash = crash
        self.binary = crash.binary
        self.project = crash.project

        self.exploit_state = crash.state

        self.target_ip = target_ip

    def _at_syscall(self, path):
        """
        Is the current path at a syscall instruction? Will it make a syscall next step?

        :param path: the path to test
        :return: True if a syscall will be executed next step
        """

        # lift only the single instruction at the current address and inspect its VEX jumpkind
        return self.project.factory.block(path.addr, num_inst=1).vex.jumpkind.startswith("Ijk_Sys")

    def reuse_input_constraints(self, sim_inp, idx, inp, state=None):
        """
        Build constraints that pin as many symbolic input bytes as possible to the bytes of a concrete input.

        The concrete input is split in half recursively (binary search). A whole chunk is pinned when doing so
        keeps the state satisfiable; otherwise each half is tried separately, down to single bytes. A byte whose
        concrete value cannot be kept is left unconstrained.

        :param sim_inp: the symbolic input; must provide get_byte(idx) and get_bytes(idx, size)
        :param idx: offset in sim_inp that corresponds to inp[0]
        :param inp: the concrete input to reuse
        :param state: the state whose constraints must stay satisfiable (defaults to the crash state)
        :return: a list of constraints that are satisfiable together with the state's constraints
        """
        if state is None:
            state = self.crash.state
        return self._reuse_input_constraints(sim_inp, idx, inp, state, [])

    def _reuse_input_constraints(self, sim_inp, idx, inp, state, prior):
        """
        Recursive worker for reuse_input_constraints().

        :param prior: constraints already accepted for earlier chunks. Every candidate is checked together with
                      them, so the returned list is jointly satisfiable rather than satisfiable only piece by piece.
        :return: the constraints accepted for this chunk (not including prior)
        """
        # end of binary search
        if not inp:
            return []

        if len(inp) == 1:
            candidate = sim_inp.get_byte(idx) == inp[0]
        else:
            # try to add constraints on the whole chunk at once
            candidate = sim_inp.get_bytes(idx, len(inp)) == inp

        if state.solver.satisfiable(extra_constraints=prior + [candidate]):
            return [candidate]

        # a single byte that cannot take its concrete value stays unconstrained
        if len(inp) == 1:
            return []

        # recursive logic: the right half must also be compatible with what the left half pinned
        bound = len(inp) // 2
        left = self._reuse_input_constraints(sim_inp, idx, inp[:bound], state, prior)
        right = self._reuse_input_constraints(sim_inp, idx + bound, inp[bound:], state, prior + left)
        return left + right

    def _concretize_input(self):
        """
        Fill in concrete_data for every RexSendAction of the crash that does not have it yet.

        When crash.use_crash_input is set, unconstrained input bytes take their values from the original crash input.
        """
        state = self.crash.state
        constraints = []
        if self.crash.use_crash_input:
            constraints = self.reuse_input_constraints(self.crash.sim_input, 0, self.crash.crash_input, state=state)
        for act in self.crash.actions:
            # exact type check (not isinstance): subclasses of RexSendAction are not concretized here
            if type(act) is RexSendAction and not act.concrete_data:
                act.concrete_data = state.solver.eval(act.sim_data, cast_to=bytes, extra_constraints=constraints)

    def script(self, filename=None, stype='py'):
        """
        Concretize the crash's send actions and render them as an exploit script.

        :param filename: optional output file, forwarded to Scripter.script
        :param stype: script type, forwarded to Scripter
        :return: whatever Scripter.script returns
        """
        self._concretize_input()
        scripter = Scripter(self.crash, stype=stype)
        return scripter.script(filename=filename)

    ####### Deprecated Functions ########

    def dump(self, filename=None):
        """
        default behavior for payload dumper

        Concretize the data the program reads from its exploit input stream: stdin for STDIN/POV_FILE crashes, or
        the input socket for TCP crashes. Optionally write the data to a file.

        :param filename: if given, the concrete payload is also written to this file
        :return: the concrete payload bytes
        :raises ExploitException: if the TCP input file descriptor cannot be found
        :raises NotImplementedError: for unsupported crash input types
        """
        state = self.crash.state

        # Determine where data goes (stdin/tcp)
        if self.crash.input_type in (CrashInputType.STDIN, CrashInputType.POV_FILE):
            stream = state.posix.stdin
        elif self.crash.input_type == CrashInputType.TCP:
            # determine which TCP socket file the input is coming from: look at open non-standard fds and closed fds
            sock_fds = [fd for fd_no, fd in state.posix.fd.items() if fd_no not in (0, 1, 2)] + \
                       [fd for _, fd in state.posix.closed_fds]
            for simfd in sock_fds:
                # the input fd is a duplex fd backed by an "aeg_input*" file that has actually been read from
                if isinstance(simfd, SimFileDescriptorDuplex) and \
                        simfd.read_storage.ident.startswith("aeg_input") and \
                        state.solver.eval(simfd.read_pos) > 0:
                    stream = simfd.read_storage
                    break
            else:
                raise ExploitException("Cannot find the exploit input file descriptor.")
        else:
            raise NotImplementedError("CrashInputType %s is not supported yet." % self.crash.input_type)

        # concretize the payload; if use_crash_input is enabled, use the original crash input
        # for unconstrained bytes
        if self.crash.use_crash_input:
            extra_constraints = self.reuse_input_constraints(self.crash.sim_input, 0, self.crash.crash_input,
                                                             state=state)
            data = stream.concretize(extra_constraints=extra_constraints)
        else:
            data = stream.concretize()

        if filename:
            with open(filename, "wb") as f:
                f.write(data)
        return data

    def _script_get_actions_tcp(self, parameterize_target_ip=False):
        """
        Produce the lines of Python source that build the payload `p` and send it over `r`.

        :param parameterize_target_ip: if True and self.target_ip is set, the first packed occurrence of the target IP
                                       in the payload is replaced by a run-time `args.target_ip` expression
        :return: a list of source lines
        """
        actions = ["p = b''"]

        content = self.dump()

        added_actions = False
        if parameterize_target_ip:
            # HACK: try to just find the target IP in the payload
            # If we find it, replace it with a run-time replacement
            l.debug("Trying to parameterize the target instruction pointer in the payload")

            if not self.target_ip:
                l.warning("Tried to parameterize the target instruction pointer without initializing me with the actual target instruction pointer")
            else:
                # little-endian first, then big-endian
                arch_bits_to_struct_format = {32: [b"<I", b">I"],
                                              64: [b"<Q", b">Q"]}
                bits = self.exploit_state.arch.bits
                if bits not in arch_bits_to_struct_format:
                    l.warning("Cannot parameterize the target instruction pointer on %d-bit arch", bits)
                else:
                    for fmt in arch_bits_to_struct_format[bits]:
                        ip = struct.pack(fmt, self.target_ip)
                        start = content.find(ip)
                        if start != -1:
                            # %r of a bytes object yields a valid Python bytes literal for the generated script
                            actions.append("# Before saved IP")
                            actions.append("p += %r" % (content[:start],))

                            actions.append("# calculate saved IP")
                            actions.append("p += struct.pack(%r, int(args.target_ip, 16))" % (fmt,))

                            actions.append("# after saved IP")
                            actions.append("p += %r" % (content[start + len(ip):],))
                            added_actions = True
                            break

        if not added_actions:
            actions.append("p += %r" % (content,))
        actions.append("r.send(p)")
        return actions

    def _script_get_actions_stdin(self, parameterize_target_ip=False):
        # FIXME: This method is deprecated. remove it later
        return self._script_get_actions_tcp(parameterize_target_ip=parameterize_target_ip)

    def pov(self):
        '''
        Write out the exploit in DARPA's POV format.

        TODO: No value information is accounted for, this will almost always just cause the register to be the value of 0,
        PC to be 0, or the address to leak to be 0
        TODO: No rerandomization has occurred at this time, the POV will act as though the target's randomness is fixed
        TODO: if anyone cares about this, it should be done with a scripter

        :return: the POV XML document as a string
        '''
        actions = []
        path = self.crash.prev
        s = self.exploit_state

        for a in path.history.actions:
            # only data actions on the exploit input ("aeg_input*") or on stdout are replayed
            if not isinstance(a, angr.state_plugins.SimActionData) or \
                    not a.type.startswith(('aeg_input', 'file_/dev/stdout')):
                continue

            # directions are mirrored: data the binary wrote becomes a POV <read>,
            # data the binary read becomes a POV <write>
            if a.action == 'write':
                size = s.solver.eval(a.size.ast)
                sval = s.solver.eval(a.data.ast, cast_to=bytes)[:size]
                read_action = dedent("""\
                        <read>
                          <length>{}</length>
                          <data format="hex">{}</data>
                        </read>""").format(size, binascii.hexlify(sval).decode())
                actions.append(read_action)

            elif a.action == 'read':
                size = s.solver.eval(a.size.ast)
                sval = s.solver.eval(a.data.ast, cast_to=bytes)[:size]
                write_action = dedent("""\
                        <write>
                          <data format="hex">{}</data>
                        </write>""").format(binascii.hexlify(sval).decode())
                actions.append(write_action)

        body = '\n'.join(actions)

        pov = dedent("""\
                <pov>
                  <cbid>{}</cbid>
                  <replay>
                {}
                  </replay>
                </pov>""").format(os.path.basename(self.binary), body)

        return pov
